{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "import d2lzh as d2l\n",
    "from mxnet import nd"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size = 256\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "定义参数模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_inputs, num_outputs, num_hiddens = 784, 10, 1000\n",
    "W1 = nd.random.normal(scale=0.005, shape=(num_inputs, num_hiddens))\n",
    "b1 = nd.zeros(num_hiddens)\n",
    "W2 = nd.random.normal(scale=0.005, shape=(num_hiddens, num_outputs))\n",
    "b2 = nd.zeros(num_outputs)\n",
    "params = [W1, b1, W2, b2]\n",
    "\n",
    "for param in params:\n",
    "    param.attach_grad()"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "定义激活函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [],
   "source": [
    "def relu(X):\n",
    "    return nd.maximum(X, 0)"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "定义模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [],
   "source": [
    "def net(X):\n",
    "    X = X.reshape((-1, num_inputs))\n",
    "    H = relu(nd.dot(X, W1) + b1)\n",
    "    return nd.dot(H, W2) + b2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义损失函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [],
   "source": [
    "from mxnet.gluon import loss as gloss \n",
    "loss = gloss.SoftmaxCrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "epoch 1, loss 1.1356, train acc 0.584, test acc 0.746\n",
      "epoch 2, loss 0.6019, train acc 0.765, test acc 0.811\n",
      "epoch 3, loss 0.5307, train acc 0.799, test acc 0.830\n",
      "epoch 4, loss 0.4681, train acc 0.826, test acc 0.851\n",
      "epoch 5, loss 0.4193, train acc 0.845, test acc 0.852\n",
      "epoch 6, loss 0.3966, train acc 0.853, test acc 0.868\n",
      "epoch 7, loss 0.3759, train acc 0.860, test acc 0.875\n",
      "epoch 8, loss 0.3583, train acc 0.868, test acc 0.860\n",
      "epoch 9, loss 0.7906, train acc 0.743, test acc 0.834\n",
      "epoch 10, loss 0.4492, train acc 0.833, test acc 0.850\n",
      "epoch 11, loss 0.4149, train acc 0.846, test acc 0.857\n",
      "epoch 12, loss 0.4044, train acc 0.851, test acc 0.869\n",
      "epoch 13, loss 0.3714, train acc 0.860, test acc 0.849\n",
      "epoch 14, loss 0.3530, train acc 0.870, test acc 0.873\n",
      "epoch 15, loss 0.3458, train acc 0.871, test acc 0.872\n",
      "epoch 16, loss 0.3297, train acc 0.877, test acc 0.877\n",
      "epoch 17, loss 0.3205, train acc 0.881, test acc 0.879\n",
      "epoch 18, loss 0.3123, train acc 0.885, test acc 0.868\n",
      "epoch 19, loss 0.3123, train acc 0.884, test acc 0.881\n",
      "epoch 20, loss 0.3030, train acc 0.887, test acc 0.884\n",
      "epoch 21, loss 0.2925, train acc 0.892, test acc 0.887\n",
      "epoch 22, loss 0.2875, train acc 0.894, test acc 0.873\n",
      "epoch 23, loss 0.2832, train acc 0.895, test acc 0.886\n",
      "epoch 24, loss 0.2802, train acc 0.896, test acc 0.884\n",
      "epoch 25, loss 0.2733, train acc 0.898, test acc 0.884\n",
      "epoch 26, loss 0.2666, train acc 0.901, test acc 0.888\n",
      "epoch 27, loss 0.2622, train acc 0.903, test acc 0.886\n",
      "epoch 28, loss 0.2645, train acc 0.901, test acc 0.881\n",
      "epoch 29, loss 0.2570, train acc 0.904, test acc 0.888\n",
      "epoch 30, loss 0.2516, train acc 0.906, test acc 0.879\n",
      "epoch 31, loss 0.2489, train acc 0.907, test acc 0.886\n",
      "epoch 32, loss 0.2605, train acc 0.903, test acc 0.886\n",
      "epoch 33, loss 0.2462, train acc 0.909, test acc 0.883\n",
      "epoch 34, loss 0.2396, train acc 0.911, test acc 0.889\n",
      "epoch 35, loss 0.2357, train acc 0.913, test acc 0.885\n",
      "epoch 36, loss 0.2312, train acc 0.914, test acc 0.892\n",
      "epoch 37, loss 0.2327, train acc 0.913, test acc 0.890\n",
      "epoch 38, loss 0.2266, train acc 0.914, test acc 0.888\n",
      "epoch 39, loss 0.2252, train acc 0.915, test acc 0.885\n",
      "epoch 40, loss 0.2207, train acc 0.917, test acc 0.891\n",
      "epoch 41, loss 0.2191, train acc 0.918, test acc 0.890\n",
      "epoch 42, loss 0.2186, train acc 0.918, test acc 0.886\n",
      "epoch 43, loss 0.2146, train acc 0.920, test acc 0.889\n",
      "epoch 44, loss 0.2120, train acc 0.920, test acc 0.890\n",
      "epoch 45, loss 0.2089, train acc 0.921, test acc 0.882\n",
      "epoch 46, loss 0.2060, train acc 0.923, test acc 0.885\n",
      "epoch 47, loss 0.2035, train acc 0.924, test acc 0.885\n",
      "epoch 48, loss 0.1987, train acc 0.926, test acc 0.892\n",
      "epoch 49, loss 0.1989, train acc 0.925, test acc 0.890\n",
      "epoch 50, loss 0.1948, train acc 0.928, test acc 0.888\n",
      "epoch 51, loss 0.1942, train acc 0.926, test acc 0.890\n",
      "epoch 52, loss 0.1934, train acc 0.928, test acc 0.889\n",
      "epoch 53, loss 0.1856, train acc 0.931, test acc 0.893\n",
      "epoch 54, loss 0.1883, train acc 0.930, test acc 0.887\n",
      "epoch 55, loss 0.1794, train acc 0.933, test acc 0.893\n",
      "epoch 56, loss 0.1816, train acc 0.932, test acc 0.890\n",
      "epoch 57, loss 0.1799, train acc 0.933, test acc 0.892\n",
      "epoch 58, loss 0.1762, train acc 0.935, test acc 0.890\n",
      "epoch 59, loss 0.1753, train acc 0.935, test acc 0.892\n",
      "epoch 60, loss 0.1740, train acc 0.935, test acc 0.893\n",
      "epoch 61, loss 0.1724, train acc 0.935, test acc 0.892\n",
      "epoch 62, loss 0.1684, train acc 0.937, test acc 0.891\n",
      "epoch 63, loss 0.1657, train acc 0.939, test acc 0.887\n",
      "epoch 64, loss 0.1652, train acc 0.937, test acc 0.889\n",
      "epoch 65, loss 0.1618, train acc 0.939, test acc 0.890\n",
      "epoch 66, loss 0.1595, train acc 0.941, test acc 0.885\n",
      "epoch 67, loss 0.1592, train acc 0.940, test acc 0.890\n",
      "epoch 68, loss 0.1544, train acc 0.943, test acc 0.889\n",
      "epoch 69, loss 0.1572, train acc 0.943, test acc 0.893\n",
      "epoch 70, loss 0.1531, train acc 0.944, test acc 0.893\n",
      "epoch 71, loss 0.1528, train acc 0.943, test acc 0.889\n",
      "epoch 72, loss 0.1560, train acc 0.943, test acc 0.888\n",
      "epoch 73, loss 0.1468, train acc 0.946, test acc 0.889\n",
      "epoch 74, loss 0.1474, train acc 0.945, test acc 0.894\n",
      "epoch 75, loss 0.1490, train acc 0.944, test acc 0.887\n",
      "epoch 76, loss 0.1461, train acc 0.945, test acc 0.894\n",
      "epoch 77, loss 0.1440, train acc 0.946, test acc 0.889\n",
      "epoch 78, loss 0.1403, train acc 0.948, test acc 0.895\n",
      "epoch 79, loss 0.1377, train acc 0.950, test acc 0.887\n",
      "epoch 80, loss 0.1414, train acc 0.947, test acc 0.891\n",
      "epoch 81, loss 0.1380, train acc 0.948, test acc 0.892\n",
      "epoch 82, loss 0.1353, train acc 0.949, test acc 0.884\n",
      "epoch 83, loss 0.1385, train acc 0.950, test acc 0.886\n",
      "epoch 84, loss 0.1330, train acc 0.951, test acc 0.889\n",
      "epoch 85, loss 0.1332, train acc 0.951, test acc 0.890\n",
      "epoch 86, loss 0.1296, train acc 0.952, test acc 0.893\n",
      "epoch 87, loss 0.1339, train acc 0.951, test acc 0.893\n",
      "epoch 88, loss 0.1271, train acc 0.953, test acc 0.895\n",
      "epoch 89, loss 0.1285, train acc 0.952, test acc 0.892\n",
      "epoch 90, loss 0.1230, train acc 0.955, test acc 0.890\n",
      "epoch 91, loss 0.1231, train acc 0.954, test acc 0.892\n",
      "epoch 92, loss 0.1266, train acc 0.953, test acc 0.887\n",
      "epoch 93, loss 0.1201, train acc 0.955, test acc 0.893\n",
      "epoch 94, loss 0.1195, train acc 0.956, test acc 0.895\n",
      "epoch 95, loss 0.1149, train acc 0.958, test acc 0.883\n",
      "epoch 96, loss 0.1176, train acc 0.956, test acc 0.892\n",
      "epoch 97, loss 0.1212, train acc 0.955, test acc 0.890\n",
      "epoch 98, loss 0.1195, train acc 0.956, test acc 0.881\n",
      "epoch 99, loss 0.1158, train acc 0.957, test acc 0.892\n",
      "epoch 100, loss 0.1137, train acc 0.958, test acc 0.885\n"
     ]
    }
   ],
   "source": [
    "num_epochs, lr = 100, 0.8\n",
    "d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size,\n",
    "             [W1, b1, W2, b2], lr)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
